from src.datasets.dataset_creator.dsl_0 import *
from src.datasets.dataset_creator.constants import *
from src.datasets.dataset_creator.arc_types import *
import numpy as np
import random
import heapq


def create_snake_input(grid_size: int = 30) -> Tuple[np.ndarray, dict]:
    """
    Create an input grid for the snake game: an orange head, a green body ending in a single
    purple tail, and a red apple, all on a black background.

    A random boolean decides whether the apple shares the head's row or the head's column.
    The body is grown one random orthogonal step at a time through empty cells. It is never
    allowed into that shared row/column, and it never touches the outer border, so it cannot
    wrap around the grid edges. The apple is then placed on the shared row/column at least
    3 cells away from the head.

    Args:
        grid_size (int): Side length of the square grid. Must be between 10 and 30 inclusive.
            Defaults to 30, the previously hard-coded size.

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is the ``grid_size x grid_size`` canvas
        (built with ``canvas``/``fill``) converted with ``np.array``. It uses the colors
        BLACK=0, RED=2 (apple), GREEN=3 (body), ORANGE=4 (head) and PURPLE=5 (tail).

    Raises:
        ValueError: If ``grid_size`` is outside [10, 30], or if no valid snake could be grown
            within 10 attempts.
    """
    if grid_size < 10 or grid_size > 30:
        raise ValueError("Grid size must be between 10x10 and 30x30")

    BLACK, RED, GREEN, ORANGE, PURPLE = 0, 2, 3, 4, 5  # Color indices

    def create_valid_snake():
        """Grow one snake.

        Returns ``(grid, body, is_same_row)``, or ``(None, None, None)`` if the snake got stuck.
        """
        grid = canvas(BLACK, (grid_size, grid_size))

        # Place the head off the border so the body has room on every side
        head = (random.randint(1, grid_size - 2), random.randint(1, grid_size - 2))
        grid = fill(grid, ORANGE, (head,))

        # True: apple goes in the head's row; False: apple goes in the head's column
        is_same_row = random.choice([True, False])

        # Grow the body; at most 9 segments after the head (fewer on small grids)
        snake_body = [head]
        for _ in range(min(9, grid_size - 3)):
            valid_positions = [
                pos
                for pos in get_empty_neighbors(grid, snake_body[-1])
                if is_valid_body_position(pos, head, is_same_row)
                and 0 < pos[0] < grid_size - 1
                and 0 < pos[1] < grid_size - 1  # stay off the border, so neighbor wrap-around never occurs
            ]
            if not valid_positions:
                return None, None, None  # Dead end: caller retries
            new_pos = random.choice(valid_positions)
            snake_body.append(new_pos)
            grid = fill(grid, GREEN, (new_pos,))

        # Recolor the last segment as the single tail
        if len(snake_body) > 1:
            grid = fill(grid, PURPLE, (snake_body[-1],))

        return grid, snake_body, is_same_row

    # Try to create a valid snake up to 10 times
    for _ in range(10):
        grid, snake_body, is_same_row = create_valid_snake()
        if grid is not None:
            break
    else:
        raise ValueError("Failed to create a valid snake after multiple attempts")

    head = snake_body[0]

    # Candidate apple cells: on the head's row/column, empty, and at least 3 cells from the head
    if is_same_row:
        candidates = [
            (head[0], j)
            for j in range(grid_size)
            if abs(j - head[1]) >= 3 and grid[head[0]][j] == BLACK
        ]
    else:
        candidates = [
            (i, head[1])
            for i in range(grid_size)
            if abs(i - head[0]) >= 3 and grid[i][head[1]] == BLACK
        ]

    if not candidates:
        # Defensive fallback: the body is kept out of the shared row/column, so this should not
        # happen for valid sizes, but place the apple in any empty cell if it ever does
        candidates = [
            (i, j) for i in range(grid_size) for j in range(grid_size) if grid[i][j] == BLACK
        ]
    if candidates:
        grid = fill(grid, RED, (random.choice(candidates),))

    return np.array(grid), {}


def get_empty_neighbors(grid: Grid, pos: Tuple[int, int]) -> List[Tuple[int, int]]:
    """
    Return the orthogonal neighbours of ``pos`` whose cell value is 0 (empty/black).

    The coordinate that moves wraps around modulo ``len(grid)``. The row count is used for both
    axes, so the grid is assumed square. Neighbours are listed in the order below, above,
    right, left.
    """
    size = len(grid)
    row, col = pos[0], pos[1]
    below, above = (row + 1) % size, (row - 1) % size
    right, left = (col + 1) % size, (col - 1) % size
    candidates = ((below, col), (above, col), (row, right), (row, left))
    return [cell for cell in candidates if grid[cell[0]][cell[1]] == 0]


def is_valid_body_position(pos: Tuple[int, int], head: Tuple[int, int], is_same_row: bool) -> bool:
    """
    Tell whether ``pos`` may hold a body segment.

    When ``is_same_row`` is truthy, the body must stay out of the head's row (index 0).
    Otherwise it must stay out of the head's column (index 1).
    """
    axis = 0 if is_same_row else 1
    return pos[axis] != head[axis]


def create_maze_input(grid_size: int = 30) -> Tuple[np.ndarray, dict]:
    """
    Create an input grid for a maze puzzle whose entrance (holding the agent) is connected to the
    center goal.

    Passages are carved by a randomized depth-first search (recursive backtracker). The search
    starts at the center cell and moves two cells at a time, so carved "room" cells share the
    center's row/column parity and the passages form a spanning tree over them. The search uses
    an explicit stack instead of recursion, so large grids do not hit Python's recursion limit.
    For a given random state it makes the same random calls in the same order as the recursive
    formulation.

    Works for both odd and even grid sizes. The entrance at row 0, column 1 is always next to a
    carved cell: (1, 1) when the center index is odd, (0, 0) when it is even.

    Args:
        grid_size (int): Side length of the square grid. Defaults to 30, the previously fixed size.

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is a ``(grid_size, grid_size)`` array with
        open cells WHITE (0), walls BLACK (1), the agent GREEN (3) at row 0 / column 1, and the
        goal BLUE (6) at the center.

    Raises:
        ValueError: If ``grid_size`` is smaller than 3.
    """
    if grid_size < 3:
        raise ValueError("grid_size must be at least 3")

    WHITE, BLACK, GREEN, BLUE = 0, 1, 3, 6  # Color indices

    def shuffled_directions():
        """Return an iterator over the four (dx, dy) unit steps in random order."""
        directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
        random.shuffle(directions)
        return iter(directions)

    def create_maze():
        # Start with all walls; coordinates are (x = column, y = row), so cells are grid[y][x]
        grid = [[BLACK for _ in range(grid_size)] for _ in range(grid_size)]

        center_x = grid_size // 2
        center_y = grid_size // 2
        grid[center_y][center_x] = WHITE

        # Each stack frame holds a cell and the iterator over its remaining shuffled directions.
        # Resuming a frame continues its iterator, which mirrors returning from a recursive call.
        stack = [(center_x, center_y, shuffled_directions())]
        while stack:
            x, y, directions = stack[-1]
            for dx, dy in directions:
                nx, ny = x + dx * 2, y + dy * 2
                if 0 <= nx < grid_size and 0 <= ny < grid_size and grid[ny][nx] == BLACK:
                    grid[y + dy][x + dx] = WHITE  # Knock down the wall between the two cells
                    grid[ny][nx] = WHITE
                    stack.append((nx, ny, shuffled_directions()))
                    break
            else:
                stack.pop()  # Every direction tried: backtrack

        # Create entrance at the top edge and place the agent there
        entrance = (0, 1)
        grid[entrance[0]][entrance[1]] = GREEN
        grid[1][1] = WHITE  # Ensure an open cell below the entrance

        # Mark center as goal (done last so it is never overwritten)
        grid[center_y][center_x] = BLUE

        return grid

    return np.array(create_maze()), {}


from queue import Queue


def create_city_network_input(grid_size: int = 30) -> Tuple[np.ndarray, dict]:
    """
    Create a city road-network grid with a guaranteed road from a start point to an end point.

    A start cell and an end cell at Manhattan distance >= ``grid_size // 2`` are chosen at random
    and joined by a shortest 4-connected road. Then ``grid_size * 3`` extra straight segments are
    painted as road. Each segment is horizontal, vertical or diagonal, 3 to
    ``max(3, grid_size // 2)`` cells long, and any cells outside the grid are skipped. Road painting
    only converts empty cells, so the start, the end and the guaranteed road are never erased.

    Args:
        grid_size (int): Side length of the square grid. Defaults to 30.

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is a ``(grid_size, grid_size)`` int array
        with 0 = empty, 1 = road, 2 = start and 3 = end.
    """
    from collections import deque

    WHITE, BLACK, START, END = 0, 1, 2, 3

    def create_path(grid, start, end):
        """Paint a shortest 4-connected path from start to end as road; return whether end was reached."""
        queue = deque([start])
        parent = {start: None}  # Also serves as the visited set

        while queue:
            current = queue.popleft()
            if current == end:
                break

            for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)):
                next_pos = (current[0] + dx, current[1] + dy)
                if (
                    0 <= next_pos[0] < grid_size
                    and 0 <= next_pos[1] < grid_size
                    and next_pos not in parent
                ):
                    queue.append(next_pos)
                    parent[next_pos] = current

        if end not in parent:
            return False

        # Walk back from end to start, turning empty cells into road (endpoints keep their markers)
        current = end
        while current != start:
            if grid[current] == WHITE:
                grid[current] = BLACK
            current = parent[current]

        return True

    def create_network():
        grid = np.zeros((grid_size, grid_size), dtype=int)

        # Place start and end points with minimum Manhattan distance
        min_distance = grid_size // 2
        while True:
            start = (random.randint(0, grid_size - 1), random.randint(0, grid_size - 1))
            end = (random.randint(0, grid_size - 1), random.randint(0, grid_size - 1))
            if abs(start[0] - end[0]) + abs(start[1] - end[1]) >= min_distance:
                break

        grid[start] = START
        grid[end] = END

        # Ensure a path exists between start and end
        create_path(grid, start, end)

        # Create additional sparse straight segments
        for _ in range(grid_size * 3):
            x, y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
            # max(3, ...) keeps randint's range valid on grids smaller than 6; unchanged for larger grids
            length = random.randint(3, max(3, grid_size // 2))
            direction = random.choice([(0, 1), (1, 0), (1, 1), (1, -1)])

            for i in range(length):
                nx, ny = x + i * direction[0], y + i * direction[1]
                if 0 <= nx < grid_size and 0 <= ny < grid_size and grid[nx, ny] == WHITE:
                    grid[nx, ny] = BLACK

        return grid

    return create_network(), {}


def create_connect_dots_input() -> Tuple[np.ndarray, dict]:
    """
    Create an input grid for a connect-the-dots puzzle: 10 to 20 dots sampled along one of four
    randomly chosen parametric curves.

    All dots are marked with the value 1; they are not numbered. Every other cell is 0. Sample
    points are clipped to the grid. If a point lands on an existing dot, it is shifted diagonally
    by (+1, +1), wrapping around the edges, until it reaches an empty cell.

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is a 30x30 ``int`` array.
    """
    row_size = 30  # TODO: work on getting this for varying shapes
    grid_size = row_size
    grid = np.zeros((grid_size, grid_size), dtype=int)

    # Parametric curves t -> (row, col) for t in (0, 2*pi]
    shapes = [
        # Small circle of radius grid_size / 4 (this entry was labelled "Square")
        lambda t: (
            int(grid_size / 2 - grid_size / 4 * np.cos(t)),
            int(grid_size / 2 - grid_size / 4 * np.sin(t)),
        ),
        # Rough "triangle": an arc for t < 2*pi/3, then a straight run along the fixed column
        # grid_size/2 + grid_size/6
        lambda t: (
            int(grid_size / 2 - grid_size / 3 * np.cos(t)),
            (
                int(grid_size / 2 - grid_size / 3 * np.sin(t))
                if t < 2 * np.pi / 3
                else int(grid_size / 2 + grid_size / 6)
            ),
        ),
        # Star-like ring: the radius alternates between grid_size/3 and grid_size/6 every pi/5 of t
        lambda t: (
            int(grid_size / 2 - grid_size / 3 * (0.5 if int(5 * t / np.pi) % 2 else 1) * np.cos(t)),
            int(grid_size / 2 - grid_size / 3 * (0.5 if int(5 * t / np.pi) % 2 else 1) * np.sin(t)),
        ),
        # Circle of radius grid_size / 3
        lambda t: (
            int(grid_size / 2 - grid_size / 3 * np.cos(t)),
            int(grid_size / 2 - grid_size / 3 * np.sin(t)),
        ),
    ]

    # Choose a random shape
    shape = random.choice(shapes)

    # Number of dots
    num_dots = random.randint(10, 20)

    # Generate dots along the shape at evenly spaced parameter values
    for i in range(1, num_dots + 1):
        t = 2 * np.pi * i / num_dots
        x, y = shape(t)
        x = max(0, min(x, grid_size - 1))
        y = max(0, min(y, grid_size - 1))

        # Resolve collisions by walking the wrapped diagonal. The walk visits grid_size distinct
        # cells and at most 19 dots already exist, so it always terminates here.
        while grid[x, y] != 0:
            x = (x + 1) % grid_size
            y = (y + 1) % grid_size

        grid[x, y] = 1

    return grid, {}


def create_flood_fill_input() -> Tuple[np.ndarray, dict]:
    """
    Create an input grid for a water-themed flood fill puzzle.

    ``grid_size // 2`` random discs of land are painted onto an empty grid. Each disc has a
    radius from 2 to ``grid_size // 4`` and a single land colour. Every cell no disc reached
    becomes water of one randomly chosen depth. One water cell is then turned into the flood's
    starting point.

    Colour codes:
        0: Light Blue (Shallow Water)
        1: Dark Blue (Deep Water)
        2: Green (Grass)
        3: Yellow (Sand)
        4: Brown (Earth)
        5: Gray (Rock)
        9: Red (Starting point for flood)

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is a 30x30 ``int`` array.
    """
    row_size = 30  # TODO: work on getting this for varying shapes
    grid_size = row_size

    grid = np.zeros((grid_size, grid_size), dtype=int)

    # Create land masses: discs clipped to the grid; later discs overwrite earlier ones
    for _ in range(grid_size // 2):
        x, y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
        radius = random.randint(2, grid_size // 4)
        land_type = random.choice([2, 3, 4, 5])
        for i in range(max(0, x - radius), min(grid_size, x + radius + 1)):
            for j in range(max(0, y - radius), min(grid_size, y + radius + 1)):
                if (i - x) ** 2 + (j - y) ** 2 <= radius**2:
                    grid[i, j] = land_type

    # Fill remaining cells with water of a single depth
    water_color = random.choice([0, 1])
    grid[grid == 0] = water_color

    if not np.isin(grid, (0, 1)).any():
        # Land covered every cell, so the rejection sampling below could never find water and
        # would loop forever. Reopen one random cell as water to host the flood source.
        grid[random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)] = water_color

    # Set starting point for flood: a uniformly random water cell
    while True:
        start_x, start_y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
        if grid[start_x, start_y] in [0, 1]:
            grid[start_x, start_y] = 9
            break

    return grid, {}


def create_checkers_input() -> Tuple[np.ndarray, dict]:
    """
    Create an 8x8 Checkers board with one black piece that can capture one white piece.

    The black piece sits in rows/columns 1-6. The white piece sits diagonally next to it, and the
    square beyond the white piece (the landing square) is on the board and empty. No other pieces
    are placed, and pieces are not restricted to dark squares.

    NOTE(review): the white piece can likewise jump the black piece. Its landing square
    (black_row - dr, black_col - dc) is always on the board and empty. Whether that counts as a
    second capture depends on the movement rules the downstream solver applies.

    Returns:
        Tuple[np.ndarray, dict]: ``(board, {})``. ``board`` is an 8x8 ``int`` array using
        0: Empty, 1: Black piece, 2: White piece, 3: Black king, 4: White king.
        Kings are never placed here.
    """
    board_size = 8

    # Step from the capturing black piece towards the white piece; the landing square is two steps away
    directions = {
        "top_left": (-1, -1),
        "top_right": (-1, 1),
        "bottom_left": (1, -1),
        "bottom_right": (1, 1),
    }

    # Resample until the landing square lies on the board
    while True:
        row, col = random.randint(1, 6), random.randint(1, 6)
        d_row, d_col = directions[random.choice(list(directions))]
        land_row, land_col = row + 2 * d_row, col + 2 * d_col
        if 0 <= land_row < board_size and 0 <= land_col < board_size:
            break

    board = np.zeros((board_size, board_size), dtype=int)
    board[row, col] = 1  # Black piece that will capture
    board[row + d_row, col + d_col] = 2  # White piece to be captured; landing square stays empty

    return board, {}


def create_light_bulb_input() -> Tuple[np.ndarray, dict]:
    """
    Create a grid with walls and light bulbs for the illumination puzzle.

    ``grid_size * 2`` wall placements are drawn uniformly at random. A cell may be drawn more
    than once, so the number of distinct wall cells can be smaller. Then
    ``max(1, grid_size // 3)`` bulbs are put on distinct empty cells by rejection sampling.

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is a 30x30 ``int`` array with walls (1),
        light bulbs (2) and empty spaces (0). The dict is always empty.
    """
    row_size = 30  # TODO: work on getting this for varying shapes
    grid_size = row_size

    grid = np.zeros((grid_size, grid_size), dtype=int)

    # Add walls (a cell may be picked more than once)
    num_walls = grid_size * 2
    for _ in range(num_walls):
        x, y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
        grid[x, y] = 1

    # Add light bulbs. The retry loop terminates because at most 60 walls + 10 bulbs leave most
    # of the 900 cells empty.
    num_bulbs = max(1, grid_size // 3)
    for _ in range(num_bulbs):
        while True:
            x, y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
            if grid[x, y] == 0:
                grid[x, y] = 2
                break

    return grid, {}


def create_radio_coverage_input() -> Tuple[np.ndarray, dict]:
    """
    Create a grid with radio towers and obstacles for the signal coverage puzzle.

    ``grid_size * 2`` obstacle placements are drawn uniformly at random. A cell may be drawn more
    than once, so the number of distinct obstacle cells can be smaller. Then
    ``max(1, grid_size // 4)`` towers are put on distinct empty cells by rejection sampling.

    Returns:
        Tuple[np.ndarray, dict]: ``(grid, {})``. ``grid`` is a 30x30 ``int`` array with obstacles (1),
        radio towers (2) and empty spaces (0). The dict is always empty.
    """
    row_size = 30  # TODO: work on getting this for varying shapes
    grid_size = row_size

    grid = np.zeros((grid_size, grid_size), dtype=int)

    # Add obstacles (a cell may be picked more than once)
    num_obstacles = grid_size * 2
    for _ in range(num_obstacles):
        x, y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
        grid[x, y] = 1

    # Add radio towers. The retry loop terminates because at most 60 obstacles + 7 towers leave
    # most of the 900 cells empty.
    num_towers = max(1, grid_size // 4)
    for _ in range(num_towers):
        while True:
            x, y = random.randint(0, grid_size - 1), random.randint(0, grid_size - 1)
            if grid[x, y] == 0:
                grid[x, y] = 2
                break

    return grid, {}
